import time
import os

import numpy as np
import torch
import gym

from typing import Optional, Dict, List, Tuple
from tqdm import tqdm
from tqdm.auto import trange  # noqa
from collections import deque

# model-based policy trainer
class MBPolicyTrainer:
    """Train a policy on a mix of real transitions and model-rollout transitions.

    Every ``rollout_freq`` steps the policy rolls out from start states sampled
    from ``real_buffer`` and the resulting transitions are appended to
    ``fake_buffer``. Each update step trains on a batch whose first
    ``int(batch_size * real_ratio)`` rows come from ``real_buffer`` and whose
    remaining rows come from ``fake_buffer``.
    """

    def __init__(
        self,
        args,
        policy,
        policy_trainer,
        real_buffer,
        fake_buffer,
        logger,
        rollout_setting: Tuple[int, int, int],
        batch_size: int = 256,
        real_ratio: float = 0.05,
        use_mobile: bool = False,
    ) -> None:
        """
        Args:
            args: Run configuration; ``train`` reads ``update_steps``,
                ``eval_every`` and ``eval_episodes`` from it.
            policy: Object providing ``rollout``, ``post_update_fn``,
                ``pre_update_fn`` and, when ``use_mobile`` is set, the
                ``compute_lcb`` / ``compute_cost_lcb`` penalty hooks.
            policy_trainer: Object providing ``train_one_step`` and ``evaluate``.
            real_buffer: Buffer of real transitions (``sample`` returns a dict
                of tensors).
            fake_buffer: Buffer receiving model-rollout transitions
                (``add_batch`` / ``sample``).
            logger: Metric/checkpoint logger.
            rollout_setting: ``(rollout_freq, rollout_batch_size, rollout_length)``.
            batch_size: Total rows per training batch (real + fake).
            real_ratio: Fraction of each batch drawn from ``real_buffer``;
                the count is truncated with ``int``.
            use_mobile: If True, apply the policy's LCB-based penalties to the
                model-generated rows before each update.
        """
        self.args = args
        self.policy = policy
        self.policy_trainer = policy_trainer
        self.real_buffer = real_buffer
        self.fake_buffer = fake_buffer
        self.logger = logger

        self._rollout_freq, self._rollout_batch_size, \
            self._rollout_length = rollout_setting
        self._batch_size = batch_size
        self._real_ratio = real_ratio
        self.use_mobile = use_mobile

    def train(self) -> Dict[str, float]:
        """Run ``args.update_steps`` updates with periodic rollouts and evaluation.

        Evaluation happens every ``args.eval_every`` steps and on the final
        step. Each evaluation saves the current checkpoint. It also saves a
        ``"best"`` checkpoint when the evaluation cost is the lowest seen so
        far; equal costs are broken by the higher reward.

        Returns:
            ``{"best_reward", "best_cost", "best_idx"}`` describing the best
            evaluation. ``best_idx`` is the step at which it occurred. If no
            evaluation ran, the values are ``-inf``, ``inf`` and ``0``.
        """
        best_reward = -np.inf
        best_cost = np.inf
        best_idx = 0

        # The real/fake split does not change during training, so compute it once.
        real_sample_size = int(self._batch_size * self._real_ratio)
        fake_sample_size = self._batch_size - real_sample_size

        for step in trange(self.args.update_steps, desc="Training"):
            if step % self._rollout_freq == 0:
                self._rollout()

            mix_batch, weights, num_real = self._sample_mixed_batch(real_sample_size, fake_sample_size)
            observations = mix_batch["observations"]
            next_observations = mix_batch["next_observations"]
            actions = mix_batch["actions"]
            rewards = mix_batch["rewards"]
            costs = mix_batch["costs"]
            done = mix_batch["terminals"]

            if self.use_mobile:
                rewards, costs = self._apply_uncertainty_penalty(
                    observations, actions, rewards, costs, num_real
                )

            self.policy_trainer.train_one_step(observations, next_observations, actions, rewards, costs,
                                               done, weights)

            is_last_step = step == self.args.update_steps - 1
            if (step + 1) % self.args.eval_every == 0 or is_last_step:
                ret, cost, length = self.policy_trainer.evaluate(self.args.eval_episodes)
                self.logger.store(tab="eval", Cost=cost, Reward=ret, Length=length)

                # save the current weight
                self.logger.save_checkpoint()
                # save the best weight: lowest cost wins, reward breaks ties
                if cost < best_cost or (cost == best_cost and ret > best_reward):
                    best_cost = cost
                    best_reward = ret
                    best_idx = step
                    self.logger.save_checkpoint(suffix="best")

                self.logger.store(tab="train", best_idx=best_idx)
                self.logger.write(step, display=False)
            else:
                self.logger.write_without_reset(step)

        return {"best_reward": best_reward, "best_cost": best_cost, "best_idx": best_idx}

    def _rollout(self) -> None:
        """Roll out the policy from real start states and store the transitions in ``fake_buffer``."""
        # post_update_fn / pre_update_fn bracket the rollout; what they toggle
        # is defined by the policy (NOTE(review): presumably mode switching — confirm).
        self.policy.post_update_fn()
        init_obss = self.real_buffer.sample(self._rollout_batch_size)["observations"].cpu().numpy()
        rollout_transitions, rollout_info = self.policy.rollout(init_obss, self._rollout_length)
        self.fake_buffer.add_batch(**rollout_transitions)
        self.logger.store(tab="model_rollout", rollout_reward_mean=rollout_info["reward_mean"],
                          rollout_cost_mean=rollout_info["cost_mean"])
        self.policy.pre_update_fn()

    def _sample_mixed_batch(
        self, real_sample_size: int, fake_sample_size: int
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, int]:
        """Sample the real and fake buffers and concatenate them, real rows first.

        Returns:
            ``(mix_batch, weights, num_real)``. ``weights`` is shaped like the
            concatenated ``costs``. It is 1 for rows from ``real_buffer`` and
            0 for rows from ``fake_buffer``. ``num_real`` is the number of
            leading real rows.
        """
        real_batch = self.real_buffer.sample(batch_size=real_sample_size)
        fake_batch = self.fake_buffer.sample(batch_size=fake_sample_size)
        mix_batch = {k: torch.cat([real_batch[k], fake_batch[k]], 0) for k in real_batch.keys()}

        costs = mix_batch["costs"]
        real_weights = torch.ones_like(real_batch["costs"], dtype=costs.dtype, device=costs.device)
        fake_weights = torch.zeros_like(fake_batch["costs"], dtype=costs.dtype, device=costs.device)
        weights = torch.cat([real_weights, fake_weights], 0)

        num_real = real_batch["observations"].shape[0]
        return mix_batch, weights, num_real

    def _apply_uncertainty_penalty(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        costs: torch.Tensor,
        num_real: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Penalize model-generated rows with the policy's LCB terms.

        The first ``num_real`` rows are real data and get a zero penalty.
        Penalties are scaled by ``policy.dynamics._penalty_coef``. They are
        subtracted from rewards when ``policy.use_reward_penalty`` is set and
        added to costs when ``policy.use_cost_penalty`` is set. If
        ``policy.binary_cost_penalty`` is set, costs are then thresholded at
        0.5 to {0, 1}.

        Returns:
            The (possibly new) ``(rewards, costs)`` tensors.
        """
        with torch.no_grad():
            # NOTE(review): assumes the LCB tensors are indexed by batch along
            # dim 0, aligned with rewards/costs — confirm against the policy.
            penalty_r = self.policy.compute_lcb(observations, actions)
            penalty_c = self.policy.compute_cost_lcb(observations, actions)
            penalty_r[:num_real] = 0.0
            penalty_c[:num_real] = 0.0
            if self.policy.use_reward_penalty:
                rewards = rewards - self.policy.dynamics._penalty_coef * penalty_r
            if self.policy.use_cost_penalty:
                costs = costs + self.policy.dynamics._penalty_coef * penalty_c
            if self.policy.binary_cost_penalty:
                # zeros_like/ones_like inherit dtype and device from costs.
                costs = torch.where(costs < 0.5, torch.zeros_like(costs), torch.ones_like(costs))
        return rewards, costs